################################################
# Project: Samuel E. Ross Econ Thesis
#
# File: SSJ Analysis
#
# Before running this file:
#
#   - Set SamuelERossEconThesis/ as the working directory (the Pkg.activate
#     call and all output paths below are relative to it).
#
################################################

# import relevant packages
using Pkg
using Plots
Pkg.activate("./Software/SSJ.jl")
Pkg.instantiate()
using SSJ

# domestic taylor rule block
#
# SSJ simple block for the domestic policy rate.
# Inputs:
#   rstar - intercept of the rule (solved for as a steady-state unknown below)
#   p     - domestic log price level (exponentiated to P elsewhere in the model)
#   ϕπ    - response coefficient on inflation
# Output:
#   r     - domestic policy rate in logs; the market-clearing block uses R = exp(r)
function _taylor(; rstar, p, ϕπ)
    π = p - p(-1) # inflation: change in log price level vs. its SSJ lag (local; shadows Base.π)
    r = rstar + ϕπ * π # taylor rule
    return (r,)
end
taylor = simple(_taylor)

# foreign taylor rule block
#
# SSJ simple block for the foreign policy rate; mirrors `_taylor`.
# Inputs:
#   rstar_f - intercept of the rule (solved for as a steady-state unknown below)
#   p_f     - foreign log price level
#   ϕπ_f    - response coefficient on foreign inflation
# Output:
#   r_f     - foreign policy rate in logs; the foreign market-clearing block uses R_f = exp(r_f)
function _taylor_f(; rstar_f, p_f, ϕπ_f)
    π_f = p_f - p_f(-1) # inflation: change in foreign log price level vs. its SSJ lag
    r_f = rstar_f + ϕπ_f * π_f # taylor rule
    return (r_f,)
end
taylor_f = simple(_taylor_f)

# domestic firm block
#
# SSJ simple block for home producers. All inputs and outputs are in logs.
# Technology: y = (1-ϕ)(a + ϑk + (1-ϑ)l) + ϕx, where x is the intermediate input
# with share ϕ, and a shifts the capital/labor part of production.
# Returns, in order: y, rk, w, pH, pH_f, π, mc
#   rk, w - log marginal products of capital and labor
#   mc    - log marginal cost (Cobb-Douglas unit cost, with p as the intermediate price)
#   pH    - home-currency price of the home good: log markup θ/(θ-1) plus mc
#   pH_f  - the same price in foreign currency, pH - e (e is the log exchange rate,
#           home currency per unit of foreign currency, as in NX of the market block)
#   π     - log profits, log((PH - MC) * Y) = y + mc - log(θ - 1)
# NOTE(review): rk and w are marginal products in output units, while mc adds the
# ϕ * p term in price units; confirm this mixing of units is intended.
function _firm(; k, l, x, a, ϑ, ϕ, θ, e, p)
    y = (1-ϕ)*(a + ϑ * k + (1-ϑ) * l) + ϕ * x # production technology
    rk = apply((1-ϕ) * ϑ, log) + y - k # rental rate of capital (log marginal product of k)
    w = apply((1-ϕ) * (1-ϑ), log) + y - l # wage (log marginal product of l)
    ϖ = ϕ^ϕ * ((1-ϕ) * ϑ^ϑ * (1-ϑ)^(1-ϑ))^(1-ϕ) # constant term of the Cobb-Douglas unit cost
    mc = (1-ϕ) * (ϑ * rk + (1-ϑ) * w - a) + ϕ * p - apply(ϖ, log) # marginal cost
    pH = apply(θ/(θ-1), log) + mc # home good price in home currency (markup over mc)
    pH_f = apply(θ/(θ-1), log) + mc - e # home good price in foreign currency
    π = y + mc - apply(θ - 1, log) # profits (log), since PH - MC = MC / (θ - 1)
    return y, rk, w, pH, pH_f, π, mc # order defines the block's outputs
end
firm = simple(_firm)

# foreign firm block
#
# SSJ simple block for foreign producers; mirrors `_firm` with foreign variables.
# All inputs and outputs are in logs.
# Returns, in order: y_f, rk_f, w_f, pF, pF_f, π_f, mc_f
#   pF_f - foreign-currency price of the foreign good: log markup θ/(θ-1) plus mc_f
#   pF   - the same price in home currency, pF_f + e
#   π_f  - log profits, y_f + mc_f - log(θ - 1)
# NOTE(review): as in `_firm`, rk_f and w_f are marginal products in output units,
# while mc_f adds a ϕ * p_f price term; confirm the units are intended.
function _firm_f(; k_f, l_f, x_f, a_f, ϑ, ϕ, θ, e, p_f)
    y_f = (1-ϕ)*(a_f + ϑ * k_f + (1-ϑ) * l_f) + ϕ * x_f # production technology
    rk_f = apply((1-ϕ) * ϑ, log) + y_f - k_f # rental rate of capital (log marginal product of k_f)
    w_f = apply((1-ϕ) * (1-ϑ), log) + y_f - l_f # wage (log marginal product of l_f)
    ϖ = ϕ^ϕ * ((1-ϕ) * ϑ^ϑ * (1-ϑ)^(1-ϑ))^(1-ϕ) # constant term of the Cobb-Douglas unit cost
    mc_f = (1-ϕ) * (ϑ * rk_f + (1-ϑ) * w_f - a_f) + ϕ * p_f - apply(ϖ, log) # marginal cost
    pF_f = apply(θ/(θ-1), log) + mc_f # foreign good price in foreign currency
    pF = apply(θ/(θ-1), log) + mc_f + e  # foreign good price in home currency
    π_f = y_f + mc_f - apply(θ - 1, log) # profits (log)
    return y_f, rk_f, w_f, pF, pF_f, π_f, mc_f # order defines the block's outputs
end
firm_f = simple(_firm_f)

# domestic household block
#
# SSJ simple block for home households. Inputs are logs; only capital is
# exponentiated here because only the law of motion needs a level.
# Returns, in order: c, i, cH, cF
#   c  - log consumption solving σc = (w - p) - l/ν
#        (presumably the intratemporal labor-supply condition)
#   i  - log investment implied by K' = (1-δ)K + I - (κ/2)(K' - K)^2 / K
#   cH - log demand for the home good: share 1-γ, shifter -γξ, price elasticity θ
#   cF - log demand for the foreign good: share γ, shifter (1-γ)ξ, price elasticity θ
# rk, π and r are accepted but not used here; they stay in the keyword list so
# the block's inputs are unchanged.
function _household(; k, l, rk, π, w, p, pH, pF, r, ξ, σ, ν, δ, κ, γ, θ)
    K = apply(k, exp) # capital level, needed for the law of motion

    c = (1/σ) * (w - (p + (1/ν)*l)) # consumption
    I = K(+1) - (1 - δ) * K + (κ / 2) * (K(+1) - K)^2 / K # law of motion
    i = apply(I, log)
    cH = apply(1-γ, log) - γ * ξ - θ * (pH - p) + c # home consumption
    cF = apply(γ, log) + (1-γ) * ξ - θ * (pF - p) + c # foreign consumption
    return c, i, cH, cF
end
household = simple(_household)

# foreign household block
#
# SSJ simple block for foreign households; mirrors `_household`. Inputs are logs;
# only capital is exponentiated because only the law of motion needs a level.
# Returns, in order: c_f, i_f, cH_f, cF_f
#   c_f  - log consumption solving σc_f = (w_f - p_f) - l_f/ν
#   i_f  - log investment implied by the capital law of motion with adjustment cost κ
#   cH_f - log foreign demand for the home good: share γ, shifter (1-γ)ξ_f
#   cF_f - log foreign demand for the foreign good: share 1-γ, shifter -γξ_f
# rk_f, π_f and r_f are accepted but not used here; they stay in the keyword list
# so the block's inputs are unchanged.
function _household_f(; k_f, l_f, rk_f, π_f, w_f, p_f, pH_f, pF_f, r_f, ξ_f, σ, ν, δ, κ, γ, θ)
    K_f = apply(k_f, exp) # capital level, needed for the law of motion

    c_f = (1/σ) * (w_f - (p_f + (1/ν)*l_f)) # consumption
    I_f = K_f(+1) - (1 - δ) * K_f + (κ / 2) * (K_f(+1) - K_f)^2 / K_f # law of motion
    i_f = apply(I_f, log)
    cH_f = apply(γ, log) + (1-γ) * ξ_f - θ * (pH_f - p_f) + c_f # home consumption
    cF_f = apply(1-γ, log) - γ * ξ_f - θ * (pF_f - p_f) + c_f # foreign consumption
    return c_f, i_f, cH_f, cF_f
end
household_f = simple(_household_f)

# arbitrageur block
#
# SSJ simple block for the currency positions. Returns, in order: D_f, D
#   D_f - foreign position: minus the realized excess return
#         r(-1) - r_f(-1) - (e - e(-1)) net of η, divided by ω σe^2 and scaled by
#         exp(p_f(-1)); η is the variable hit by the "credibility" shocks below
#   D   - home position, set so that
#         D / exp(r(-1)) + exp(e(-1)) * D_f / exp(r_f(-1)) = 0,
#         i.e. the two positions offset at the previous period's rates and exchange rate
function _arbitrageurs(; r, r_f, e, p_f, ω, σe, η)
    D_f = -apply(p_f(-1), exp) * (r(-1) - r_f(-1) - e + e(-1) - η) / (ω * σe^2) # foreign position
    D = -apply(r(-1), exp) * apply(e(-1), exp) * D_f / apply(r_f(-1), exp) # home position (offsets D_f)
    return D_f, D
end
arbitrageurs = simple(_arbitrageurs)

# domestic market clearing block
#
# SSJ simple block returning the home equilibrium residuals (each zero in equilibrium).
# Inputs are logs and are exponentiated to levels (upper-case names) where needed.
# Returns, in order:
#   goods_mkt      - Y minus demand for the home good from both countries
#   euler          - consumption Euler equation, 1 = β R (C'/C)^(-σ) P/P'
#   capital_euler  - capital Euler equation with quadratic adjustment cost κ
#   walras         - household budget: P(C + I) + D - D'/R - (W L + rK K + Π)
#   country_budget - NX + D'/R - D, with NX = E * PH_f * CH_f - PF * CF
#   prices         - CES price-index condition for P, expressed relative to P
# NOTE(review): the home-good demand terms scale expenditure shares of C + X + I by
# θ/(θ-1) instead of using the cH / cH_f demands; confirm this is intended.
# cH, cF_f, pF_f, mc, mc_f and ϕ are accepted but not used; they stay in the
# keyword list so the block's inputs are unchanged.
function _mkt_clearing(; y, c, c_f, cH, cH_f, cF, cF_f, x, x_f, i, i_f, k, l, rk, π, w, p, pH, pH_f, pF, pF_f, r, D, mc, mc_f, e, θ, γ, β, σ, κ, δ, ϕ, ξ)
    # de-log only the variables the residuals below use
    Y = apply(y, exp)
    C = apply(c, exp)
    C_f = apply(c_f, exp)
    CH_f = apply(cH_f, exp)
    CF = apply(cF, exp)
    X = apply(x, exp)
    X_f = apply(x_f, exp)
    I = apply(i, exp)
    I_f = apply(i_f, exp)
    K = apply(k, exp)
    L = apply(l, exp)
    rK = apply(rk, exp)
    Π = apply(π, exp)
    W = apply(w, exp)
    P = apply(p, exp)
    PH = apply(pH, exp)
    PH_f = apply(pH_f, exp)
    PF = apply(pF, exp)
    R = apply(r, exp)

    # demand for the home good from home (share 1-γ) and from abroad (share γ)
    YH = (θ / (θ - 1)) * (1-γ) * (C + X + I)
    YH_f = (θ / (θ - 1)) * γ * (C_f + X_f + I_f)

    goods_mkt = Y - (YH + YH_f) # goods market clearing
    euler = 1 - β * R * (C(+1)/C)^(-σ) * (P / P(+1)) # consumption euler equation clearing
    capital_euler = 1 + κ * (K(+1) - K) / K - β * (C(+1) / C)^(-σ) * (rK(+1) / P(+1) + (1-δ) + κ * (K(+2) - K(+1))/K(+1) + (κ/2) * ((K(+2) - K(+1))/K(+1))^2) # capital euler equation clearing
    walras = P * C + P * I - D(+1) / R - W * L - rK * K + D - Π # walras condition
    NX = apply(e, exp) * PH_f * CH_f - PF * CF # define net exports
    country_budget = NX + D(+1)/R - D # country budget constraint
    prices = (P - ((1-γ) * apply(-γ * ξ, exp) * PH^(1-θ) + γ * apply((1-γ) * ξ, exp) * PF^(1-θ))^(1/(1-θ)))/P # price-setting
    return goods_mkt, euler, capital_euler, walras, country_budget, prices
end
mkt_clearing = simple(_mkt_clearing)

# foreign market clearing block
#
# SSJ simple block returning the foreign equilibrium residuals; mirrors `_mkt_clearing`.
# Inputs are logs and are exponentiated to levels (upper-case names) where needed.
# Returns, in order:
#   goods_mkt_f      - Y_f minus demand for the foreign good from both countries
#   euler_f          - foreign consumption Euler equation
#   capital_euler_f  - foreign capital Euler equation with adjustment cost κ
#   walras_f         - foreign household budget condition
#   country_budget_f - NX_f + D_f'/R_f - D_f, with NX_f = PF * CF / E - PH_f * CH_f
#   prices_f         - CES price-index condition for P_f, expressed relative to P_f
# cH, cF_f, pH, mc and mc_f are accepted but not used; they stay in the keyword
# list so the block's inputs are unchanged.
function _mkt_clearing_f(; y_f, c, c_f, cH, cH_f, cF, cF_f, x, x_f, i, i_f, k_f, l_f, rk_f, π_f, w_f, p_f, pH, pH_f, pF, pF_f, r_f, D_f, mc, mc_f, e, θ, γ, β, σ, κ, δ, ξ_f)
    # de-log only the variables the residuals below use
    Y_f = apply(y_f, exp)
    C = apply(c, exp)
    C_f = apply(c_f, exp)
    CH_f = apply(cH_f, exp)
    CF = apply(cF, exp)
    X = apply(x, exp)
    X_f = apply(x_f, exp)
    I = apply(i, exp)
    I_f = apply(i_f, exp)
    K_f = apply(k_f, exp)
    L_f = apply(l_f, exp)
    rK_f = apply(rk_f, exp)
    Π_f = apply(π_f, exp)
    W_f = apply(w_f, exp)
    P_f = apply(p_f, exp)
    PH_f = apply(pH_f, exp)
    PF = apply(pF, exp)
    PF_f = apply(pF_f, exp)
    R_f = apply(r_f, exp)

    # demand for the foreign good from abroad (share 1-γ) and from home (share γ)
    YF_f = (θ / (θ - 1)) * (1-γ) * (C_f + X_f + I_f)
    YF = (θ / (θ - 1)) * γ * (C + X + I)

    goods_mkt_f = Y_f - (YF + YF_f) # goods market clearing
    euler_f = 1 - β * R_f * (C_f(+1)/C_f)^(-σ) * (P_f / P_f(+1)) # consumption euler equation clearing
    capital_euler_f = 1 + κ * (K_f(+1) - K_f) / K_f - β * (C_f(+1) / C_f)^(-σ) * (rK_f(+1) / P_f(+1) + (1-δ) + κ * (K_f(+2) - K_f(+1))/K_f(+1) + (κ/2) * ((K_f(+2) - K_f(+1))/K_f(+1))^2) # capital euler equation clearing
    walras_f = P_f * C_f + P_f * I_f - D_f(+1) / R_f - W_f * L_f - rK_f * K_f + D_f - Π_f # walras condition
    NX_f = PF * CF / apply(e, exp) - PH_f * CH_f # define net exports
    country_budget_f = NX_f + D_f(+1)/R_f - D_f # country budget constraint
    prices_f = (P_f - (γ * apply((1-γ) * ξ_f, exp) * PH_f^(1-θ) + (1-γ) * apply(-γ * ξ_f, exp) * PF_f^(1-θ))^(1/(1-θ)))/P_f # price-setting
    return goods_mkt_f, euler_f, capital_euler_f, walras_f, country_budget_f, prices_f
end
mkt_clearing_f = simple(_mkt_clearing_f)

# Assemble the model DAG from the simple blocks defined above
# (order: firms, households, Taylor rules, arbitrageurs, market clearing).
model = create_model(
    [
        firm, firm_f,
        household, household_f,
        taylor, taylor_f,
        arbitrageurs,
        mkt_clearing, mkt_clearing_f,
    ];
    name="Model",
)

# Report the assembled model and the blocks it contains.
println(model)
println("Blocks: $(model.blocks)")
println()

# Steady-state calibration, grouped by the block(s) that read each entry.
calibration = Dict{String, Float64}(
    # firm blocks (θ is also read by households and market clearing)
    "a" => 0.0,
    "a_f" => 0.0,
    "ϑ" => 0.3,
    "ϕ" => 0.5,
    "θ" => 1.5,
    # household blocks (σ and γ are also read by market clearing)
    "ξ" => 0.0,
    "ξ_f" => 0.0,
    "ν" => 1.0,
    "δ" => 0.02,
    "κ" => 10.0,
    "γ" => 0.035,
    "σ" => 2.0,
    # market clearing
    "β" => 0.99,
    # Taylor rules
    "ϕπ" => 2.15,
    "ϕπ_f" => 2.15,
    # arbitrageurs (p_f is also read by the foreign firm/household/market blocks)
    "ω" => 1.0,
    "σe" => 1.0,
    "η" => 0.0010,
    "p_f" => 0.0,
    # NOTE(review): not read by any block in this file; likely leftovers, verify
    "ρm" => 0.95,
    "ϕe" => 0.0,
    "ϕe_f" => 0.0,
    "σm" => 0.0,
    "ϵm" => 0.0,
    "ϵm_f" => 0.0,
    "ebar" => 0.0,
)

# Initial guesses for the steady-state unknowns (k, l, x, p, e are logs;
# rstar / rstar_f are the Taylor-rule intercepts).
unknowns_ss = Dict{String, Float64}(
    # home economy
    "k" => 0.0, "l" => 2.0, "x" => 0.0, "p" => 0.0, "rstar" => 0.0,
    # foreign economy
    "k_f" => 0.0, "l_f" => 2.0, "x_f" => 0.0, "rstar_f" => 0.0,
    # exchange rate
    "e" => 0.0,
)

# Steady-state targets: every market-clearing residual except walras / walras_f
# must be driven to zero.
targets_ss = Dict{String, Float64}(
    name => 0.0 for name in (
        "goods_mkt", "goods_mkt_f",
        "euler", "euler_f",
        "capital_euler", "capital_euler_f",
        "country_budget", "country_budget_f",
        "prices", "prices_f",
    )
)

# Solve the steady state with SSJ's "broyden_custom" solver, then report it.
ss = solve_steady_state(
    model, calibration, unknowns_ss, targets_ss;
    solver="broyden_custom",
)
println("Steady State: ", ss.toplevel)
println()

# specify unknowns, targets, and exogenous variables for impulse responses
# NOTE(review): "ϵm" and "ϵm_f" are not read by any block in this file; confirm SSJ
# accepts inputs that no block uses, or drop them.
unknowns = ["k", "l", "x", "rstar", "k_f", "l_f", "x_f", "rstar_f", "e", "p"]
targets = ["goods_mkt", "goods_mkt_f", "euler", "euler_f", "capital_euler", "capital_euler_f", "country_budget", "country_budget_f", "prices", "prices_f"]
inputs = ["a", "a_f", "ξ", "ξ_f", "ϵm", "ϵm_f", "η"]

# Time truncation, defined once: the Jacobians are T×T, and the shock paths below
# must have T rows for G * dη to be defined.
T = 300

# solve the G jacobian using the SSJ method
G = solve_jacobian(model, ss, unknowns, targets, inputs; T=T)

# AR(1) paths for the dη (credibility) shock. Column j is impact_j * ρ^t for t = 0..T-1.
# The impacts are 10%, 20%, and 30% of the calibrated η = 0.0010.
impact1, impact2, impact3, ρ = 0.0001, 0.0002, 0.0003, 0.8
impacts = (impact1, impact2, impact3)
dη = zeros(T, length(impacts))
for (j, impact) in enumerate(impacts)
    dη[:, j] = impact * ρ .^ (0:T-1)
end

# plot the credibility shocks as a percent of steady-state η over the first 50 quarters
mkpath("Output/Results/SSJ") # ensure the output directory exists before saving
plot(100 * dη[1:50, 1] / ss["η"], label="10% shock", linewidth=2.5)
plot!(100 * dη[1:50, 2] / ss["η"], label="20% shock", linewidth=2.5)
plot!(100 * dη[1:50, 3] / ss["η"], label="30% shock", linewidth=2.5)
title!("Credibility Shocks")
ylabel!("Percent Deviation from Steady State")
xlabel!("Quarter")
savefig("Output/Results/SSJ/im.pdf")

"""
    plot_irf(G, shocks, var, title, filename; legend=:topright, shock="η",
             horizon=50, outdir="Output/Results/SSJ",
             labels=("10% shock", "20% shock", "30% shock"))

Compute `100 * G[var][shock] * shocks`, the response of `var` to each column of
`shocks`, scaled by 100. Plot the first `horizon` periods of each column with the
matching entry of `labels`, and save the figure to `joinpath(outdir, filename)`.
The output directory is created if it does not already exist.
Return the full response matrix.
"""
function plot_irf(G, shocks, var, title, filename; legend=:topright, shock="η",
                  horizon=50, outdir="Output/Results/SSJ",
                  labels=("10% shock", "20% shock", "30% shock"))
    response = 100 * G[var][shock] * shocks
    plot(response[1:horizon, 1], label=labels[1], linewidth=2.5)
    for j in 2:length(labels)
        plot!(response[1:horizon, j], label=labels[j], linewidth=2.5)
    end
    plot!(title=title, ylabel="Percent Deviation from Steady State",
          xlabel="Quarter", legend=legend)
    mkpath(outdir) # savefig needs the target directory to exist
    savefig(joinpath(outdir, filename))
    return response
end

# impulse responses to the credibility shocks (results kept as globals for inspection)
dc = plot_irf(G, dη, "c", "Consumption Response to Credibility Shocks",
              "im_consshock.pdf"; legend=:bottomright)
di = plot_irf(G, dη, "i", "Investment Response to Credibility Shocks",
              "im_invshock.pdf"; legend=:topright)
dy = plot_irf(G, dη, "y", "Output Response to Credibility Shocks",
              "im_outshock.pdf"; legend=:topright)
dc_f = plot_irf(G, dη, "c_f", "For. Consumption Response to Credibility Shocks",
                "im_consshock_f.pdf"; legend=:bottomright)
di_f = plot_irf(G, dη, "i_f", "For. Investment Response to Credibility Shocks",
                "im_invshock_f.pdf"; legend=:bottomright)
de = plot_irf(G, dη, "e", "Exchange Rate Response to Credibility Shocks",
              "im_exshock.pdf"; legend=:topright)
dk = plot_irf(G, dη, "k", "Capital Response to Credibility Shocks",
              "im_capshock.pdf"; legend=:bottomright)
dp = plot_irf(G, dη, "p", "Price Response to Credibility Shocks",
              "im_priceshock.pdf"; legend=:topright)
dr = plot_irf(G, dη, "r", "Interest Rate Response to Credibility Shocks",
              "im_rateshock.pdf"; legend=:topright)
dl = plot_irf(G, dη, "l", "Labor Response to Credibility Shocks",
              "im_laborshock.pdf"; legend=:bottomright)
dx = plot_irf(G, dη, "x", "Intermediate Input Response to Credibility Shocks",
              "im_intermediateshock.pdf"; legend=:topright)
